{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pyspiel\n",
    "import tensorflow.compat.v1 as tf\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "import algorithms.rcfr as rcfr_tf\n",
    "import pytorch.rcfr as rcfr_pt\n",
    "\n",
    "# Use the TF1-compatible API; disable_v2_behavior() also turns eager\n",
    "# execution off, so it is explicitly re-enabled right after.\n",
    "tf.disable_v2_behavior()\n",
    "tf.enable_eager_execution()\n",
    "\n",
    "# Seed the numpy, PyTorch and TensorFlow RNGs once (after eager mode is on)\n",
    "# so that re-running the whole notebook gives repeatable results.\n",
    "_SEED = 0\n",
    "np.random.seed(_SEED)\n",
    "torch.manual_seed(_SEED)\n",
    "tf.set_random_seed(_SEED)\n",
    "\n",
    "_GAME = pyspiel.load_game('kuhn_poker')\n",
    "_BATCH_SIZE = 12"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tensorflow_example(game_name, num_epochs, iterations, buffer_size=-1,\n",
    "                       truncate_negative=False, bootstrap=False,\n",
    "                       eval_every=10):\n",
    "  \"\"\"Runs TensorFlow RCFR on `game_name` and records exploitability.\n",
    "\n",
    "  Args:\n",
    "    game_name: name of the game passed to `pyspiel.load_game`.\n",
    "    num_epochs: number of dataset repeats each time a model is trained.\n",
    "    iterations: number of `evaluate_and_update_policy` calls to make.\n",
    "    buffer_size: if positive, use `ReservoirRcfrSolver` with this buffer\n",
    "      size; otherwise use `RcfrSolver`.\n",
    "    truncate_negative: forwarded to the solver.\n",
    "    bootstrap: forwarded to `RcfrSolver` (unused when `buffer_size > 0`).\n",
    "    eval_every: record exploitability every `eval_every` iterations.\n",
    "\n",
    "  Returns:\n",
    "    List of `pyspiel.exploitability` values of the solver's average policy,\n",
    "    taken at iterations 0, eval_every, 2 * eval_every, ...\n",
    "\n",
    "  Raises:\n",
    "    ValueError: if `eval_every` is not positive.\n",
    "  \"\"\"\n",
    "  if eval_every < 1:\n",
    "    raise ValueError('eval_every must be positive, got {}'.format(eval_every))\n",
    "  game = pyspiel.load_game(game_name)\n",
    "\n",
    "  # One DeepRcfrModel per player.\n",
    "  models = [\n",
    "      rcfr_tf.DeepRcfrModel(\n",
    "          game,\n",
    "          num_hidden_layers=1,\n",
    "          num_hidden_units=13,\n",
    "          num_hidden_factors=8,\n",
    "          use_skip_connections=True)\n",
    "      for _ in range(game.num_players())\n",
    "  ]\n",
    "\n",
    "  if buffer_size > 0:\n",
    "    solver = rcfr_tf.ReservoirRcfrSolver(\n",
    "        game,\n",
    "        models,\n",
    "        buffer_size,\n",
    "        truncate_negative=truncate_negative)\n",
    "  else:\n",
    "    solver = rcfr_tf.RcfrSolver(\n",
    "        game,\n",
    "        models,\n",
    "        truncate_negative=truncate_negative,\n",
    "        bootstrap=bootstrap)\n",
    "\n",
    "  def _train_fn(model, data):\n",
    "    \"\"\"Train `model` on `data` for `num_epochs` epochs with a Huber loss.\"\"\"\n",
    "    batch_size = 100\n",
    "    step_size = 0.01\n",
    "    data = data.shuffle(batch_size * 10)\n",
    "    data = data.batch(batch_size)\n",
    "    data = data.repeat(num_epochs)\n",
    "\n",
    "    # `learning_rate` replaces the deprecated `lr` alias, which some newer\n",
    "    # Keras optimizer versions reject outright.\n",
    "    optimizer = tf.keras.optimizers.Adam(learning_rate=step_size, amsgrad=True)\n",
    "\n",
    "    @tf.function\n",
    "    def _train():\n",
    "      for x, y in data:\n",
    "        optimizer.minimize(\n",
    "            lambda: tf.losses.huber_loss(y, model(x), delta=0.01),  # pylint: disable=cell-var-from-loop\n",
    "            model.trainable_variables)\n",
    "\n",
    "    _train()\n",
    "\n",
    "  result = []\n",
    "  for i in range(iterations):\n",
    "    solver.evaluate_and_update_policy(_train_fn)\n",
    "    if i % eval_every == 0:\n",
    "      result.append(pyspiel.exploitability(game, solver.average_policy()))\n",
    "  return result\n",
    "\n",
    "\n",
    "# Backward-compatible alias for the original (misspelled) name.\n",
    "tnsorflow_example = tensorflow_example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pytorch_example(game_name, num_epochs, iterations, buffer_size=-1,\n",
    "                    truncate_negative=False, bootstrap=False,\n",
    "                    eval_every=10):\n",
    "  \"\"\"Runs PyTorch RCFR on `game_name` and records exploitability.\n",
    "\n",
    "  Uses the same model sizes, solver settings, Adam(amsgrad) optimizer and\n",
    "  Huber loss (delta=0.01) as the TensorFlow example, so the curves compare.\n",
    "\n",
    "  Args:\n",
    "    game_name: name of the game passed to `pyspiel.load_game`.\n",
    "    num_epochs: number of epochs each time a model is trained.\n",
    "    iterations: number of `evaluate_and_update_policy` calls to make.\n",
    "    buffer_size: if positive, use `ReservoirRcfrSolver` with this buffer\n",
    "      size; otherwise use `RcfrSolver`.\n",
    "    truncate_negative: forwarded to the solver.\n",
    "    bootstrap: forwarded to `RcfrSolver` (unused when `buffer_size > 0`).\n",
    "    eval_every: record exploitability every `eval_every` iterations.\n",
    "\n",
    "  Returns:\n",
    "    List of `pyspiel.exploitability` values of the solver's average policy,\n",
    "    taken at iterations 0, eval_every, 2 * eval_every, ...\n",
    "\n",
    "  Raises:\n",
    "    ValueError: if `eval_every` is not positive.\n",
    "  \"\"\"\n",
    "  if eval_every < 1:\n",
    "    raise ValueError('eval_every must be positive, got {}'.format(eval_every))\n",
    "  game = pyspiel.load_game(game_name)\n",
    "\n",
    "  # One DeepRcfrModel per player.\n",
    "  models = [\n",
    "      rcfr_pt.DeepRcfrModel(\n",
    "          game,\n",
    "          num_hidden_layers=1,\n",
    "          num_hidden_units=13,\n",
    "          num_hidden_factors=8,\n",
    "          use_skip_connections=True)\n",
    "      for _ in range(game.num_players())\n",
    "  ]\n",
    "\n",
    "  if buffer_size > 0:\n",
    "    solver = rcfr_pt.ReservoirRcfrSolver(\n",
    "        game,\n",
    "        models,\n",
    "        buffer_size,\n",
    "        truncate_negative=truncate_negative)\n",
    "  else:\n",
    "    solver = rcfr_pt.RcfrSolver(\n",
    "        game,\n",
    "        models,\n",
    "        truncate_negative=truncate_negative,\n",
    "        bootstrap=bootstrap)\n",
    "\n",
    "  def _huber_loss(output, target, delta=0.01):\n",
    "    \"\"\"Element-wise mean Huber loss, as defined by `tf.losses.huber_loss`.\n",
    "\n",
    "    0.5 * e**2 where |e| <= delta, else delta * (|e| - 0.5 * delta). Written\n",
    "    out by hand because `nn.SmoothL1Loss` (default beta=1) is a different\n",
    "    loss and `nn.HuberLoss` requires PyTorch >= 1.9.\n",
    "    \"\"\"\n",
    "    abs_error = torch.abs(output - target)\n",
    "    quadratic = torch.clamp(abs_error, max=delta)\n",
    "    linear = abs_error - quadratic\n",
    "    return torch.mean(0.5 * quadratic ** 2 + delta * linear)\n",
    "\n",
    "  def _train_fn(model, data):\n",
    "    \"\"\"Train `model` on `data` for `num_epochs` epochs with a Huber loss.\"\"\"\n",
    "    batch_size = 100\n",
    "    step_size = 0.01\n",
    "\n",
    "    loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=step_size, amsgrad=True)\n",
    "\n",
    "    # `num_epochs` is the enclosing function's argument (no local override),\n",
    "    # so the epoch count requested by the caller is actually honored.\n",
    "    for _ in range(num_epochs):\n",
    "      for x, y in loader:\n",
    "        optimizer.zero_grad()\n",
    "        loss = _huber_loss(model(x), y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "  result = []\n",
    "  for i in range(iterations):\n",
    "    solver.evaluate_and_update_policy(_train_fn)\n",
    "    if i % eval_every == 0:\n",
    "      result.append(pyspiel.exploitability(game, solver.average_policy()))\n",
    "  return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs, iterations = 20, 100\n",
    "tensorflow_rcfr, pytorch_rcfr = [], []\n",
    "\n",
    "# Ten repetitions on Kuhn poker, alternating TensorFlow and PyTorch runs.\n",
    "for _ in range(10):\n",
    "  tensorflow_rcfr.append(\n",
    "      tnsorflow_example('kuhn_poker', num_epochs, iterations))\n",
    "  pytorch_rcfr.append(\n",
    "      pytorch_example('kuhn_poker', num_epochs, iterations))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Average (not sum) over the repeated runs so the y-axis is a real\n",
    "# exploitability value; it is recorded every 10 iterations, so checkpoint\n",
    "# k corresponds to iteration 10 * k.\n",
    "tf_exploitability = np.mean(tensorflow_rcfr, axis=0)\n",
    "pt_exploitability = np.mean(pytorch_rcfr, axis=0)\n",
    "x = 10 * np.arange(len(tf_exploitability))\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(x, tf_exploitability, label=\"tensorflow\")\n",
    "ax.plot(x, pt_exploitability, label=\"pytorch\")\n",
    "ax.set(\n",
    "    title=f\"kuhn_poker, {num_epochs} epochs (mean of {len(tensorflow_rcfr)} runs)\",\n",
    "    xlabel=\"iteration\",\n",
    "    ylabel=\"exploitability\")\n",
    "ax.legend()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs, iterations = 200, 100\n",
    "tensorflow_rcfr, pytorch_rcfr = [], []\n",
    "\n",
    "# Same Kuhn poker comparison with 10x more training epochs per update.\n",
    "for _ in range(10):\n",
    "  tensorflow_rcfr.append(\n",
    "      tnsorflow_example('kuhn_poker', num_epochs, iterations))\n",
    "  pytorch_rcfr.append(\n",
    "      pytorch_example('kuhn_poker', num_epochs, iterations))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Average (not sum) over the repeated runs so the y-axis is a real\n",
    "# exploitability value; it is recorded every 10 iterations, so checkpoint\n",
    "# k corresponds to iteration 10 * k.\n",
    "tf_exploitability = np.mean(tensorflow_rcfr, axis=0)\n",
    "pt_exploitability = np.mean(pytorch_rcfr, axis=0)\n",
    "x = 10 * np.arange(len(tf_exploitability))\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(x, tf_exploitability, label=\"tensorflow\")\n",
    "ax.plot(x, pt_exploitability, label=\"pytorch\")\n",
    "ax.set(\n",
    "    title=f\"kuhn_poker, {num_epochs} epochs (mean of {len(tensorflow_rcfr)} runs)\",\n",
    "    xlabel=\"iteration\",\n",
    "    ylabel=\"exploitability\")\n",
    "ax.legend()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs, iterations = 20, 100\n",
    "tensorflow_rcfr, pytorch_rcfr = [], []\n",
    "\n",
    "# Ten repetitions on Leduc poker, alternating TensorFlow and PyTorch runs.\n",
    "for _ in range(10):\n",
    "  tensorflow_rcfr.append(\n",
    "      tnsorflow_example('leduc_poker', num_epochs, iterations))\n",
    "  pytorch_rcfr.append(\n",
    "      pytorch_example('leduc_poker', num_epochs, iterations))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Average (not sum) over the repeated runs so the y-axis is a real\n",
    "# exploitability value; it is recorded every 10 iterations, so checkpoint\n",
    "# k corresponds to iteration 10 * k.\n",
    "tf_exploitability = np.mean(tensorflow_rcfr, axis=0)\n",
    "pt_exploitability = np.mean(pytorch_rcfr, axis=0)\n",
    "x = 10 * np.arange(len(tf_exploitability))\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(x, tf_exploitability, label=\"tensorflow\")\n",
    "ax.plot(x, pt_exploitability, label=\"pytorch\")\n",
    "ax.set(\n",
    "    title=f\"leduc_poker, {num_epochs} epochs (mean of {len(tensorflow_rcfr)} runs)\",\n",
    "    xlabel=\"iteration\",\n",
    "    ylabel=\"exploitability\")\n",
    "ax.legend()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
